Learning the Lennard-Jones (LJ) Potential¶

This notebook demonstrates how to learn a pair potential (e.g., Lennard-Jones) from Molecular Dynamics (MD) simulations by optimizing against Radial Distribution Functions (RDFs) and Velocity Auto-Correlation Functions (VACFs).

Key Concepts¶

  • Lennard-Jones Potential: Describes interactions between two particles:

$$ U(r) = 4 \varepsilon \left[ \left( \frac{\sigma}{r} \right)^{12} - \left( \frac{\sigma}{r} \right)^6 \right] $$

  • MD Simulations: Use a Nosé-Hoover chain thermostat to evolve the system in the NVT ensemble.
  • Objective: Learn a potential function that reproduces the RDF and VACF of a reference system.
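The formula above is easy to sanity-check in plain NumPy (a standalone snippet, not part of the training code): U(sigma) = 0, and the minimum sits at r = 2^(1/6) sigma with depth -epsilon.

```python
import numpy as np

def lj(r, sigma=1.0, epsilon=1.0):
    """Lennard-Jones pair energy U(r) in reduced units."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)

print(lj(1.0))                       # 0.0 at r = sigma
print(round(lj(2.0 ** (1 / 6)), 9))  # -1.0 at the minimum r = 2**(1/6) * sigma
```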

Importing Required Libraries¶

In [77]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt

from ase import Atoms
from ase.lattice.cubic import FaceCenteredCubic, Diamond

from ase.visualize import *
from data_src.data import *
from potential_src.pairMLP.potential_PairMLP import *
from observables.rdf import *
from observables.observers import *   
from utils.get_utils import *
def plot_vacf(vacf_sim, vacf_target, fn, path, dt=0.01, save_data=False):

    t_range = np.arange(vacf_sim.shape[0]) * dt

    plt.plot(t_range, vacf_sim, label='simulation', linewidth=4, alpha=0.6)

    if vacf_target is not None:
        plt.plot(t_range, vacf_target, label='target', linewidth=2, linestyle='--', c='black')

    plt.legend()

    if save_data:
        np.savetxt(path + '/vacf_{}.txt'.format(fn), np.stack((t_range, vacf_sim)), delimiter=',')
        if vacf_target is not None:
            np.savetxt(path + '/vacf_{}_target.txt'.format(fn), np.stack((t_range, vacf_target)), delimiter=',')

    # save before show(): show() clears the current figure
    plt.savefig(path + '/vacf_{}.pdf'.format(fn), bbox_inches='tight')
    plt.show()
    plt.close()

def plot_rdf(g_sim, rdf_target, fn, path, start, nbins, save_data=False, end=2.5):

    bins = np.linspace(start, end, nbins)

    plt.plot(bins, g_sim, label='simulation', linewidth=4, alpha=0.6)
    plt.plot(bins, rdf_target, label='target', linewidth=2, linestyle='--', c='black')

    plt.xlabel(r"$r$ ($\AA$)")
    plt.ylabel("g(r)")
    plt.legend()

    if save_data:
        np.savetxt(path + '/rdf_{}.txt'.format(fn), np.stack((bins, g_sim)), delimiter=',')
        np.savetxt(path + '/rdf_{}_target.txt'.format(fn), np.stack((bins, rdf_target)), delimiter=',')

    # save before show(): show() clears the current figure
    plt.savefig(path + '/rdf_{}.pdf'.format(fn), bbox_inches='tight')
    plt.show()
    plt.close()

def plot_pair(fn, path, model, prior, device, end=2.5, target_pot=None): 

    if target_pot is None:
        target_pot = LennardJones(1.0, 1.0)
    else:
        target_pot = target_pot.to("cpu")

    x = torch.linspace(0.1, end, 250)[:, None].to(device)
    
    u_fit = (model(x) + prior(x)).detach().cpu().numpy()
    u_fit = u_fit - u_fit[-1] 

    u_target = target_pot(x.detach().cpu()).squeeze()

    plt.plot( x.detach().cpu().numpy(), 
              u_fit, 
              label='fit', linewidth=4, alpha=0.6)
    
    plt.plot( x.detach().cpu().numpy(), 
              u_target.detach().cpu().numpy(),
               label='truth', 
               linewidth=2,linestyle='--', c='black')

    plt.ylim(-2, 4.0)
    plt.legend()
    # save before show(): show() clears the current figure
    plt.savefig(path + '/potential_{}.jpg'.format(fn), bbox_inches='tight')
    plt.show()
    plt.close()

    return u_fit

System Definition¶

We first define the simulation parameters, such as density, temperature, and the cutoff radius. We also initialize the MD system using ASE (Atomic Simulation Environment).

In [2]:
pair_data_dict = {  'lj_0.3_1.2': {
            'rdf_fn': '../data/LJ_data/rdf_rho0.3_T1.2_dt0.01.csv' ,
            'vacf_fn': '../data/LJ_data/vacf_rho0.3_T1.2_dt0.01.csv' ,
            'rho': 0.3,
            'T': 1.2, 
            'start': 0.75, 
            'end': 3.3,
            'element': "H",
            'mass': 1.0,
            "N_unitcell": 4,
            "cell": FaceCenteredCubic,
            "target_pot": LennardJones()
            },
  }

width_dict = {'tiny': 64,
               'low': 128,
               'mid': 256, 
               'high': 512}

gaussian_dict = {'tiny': 16,
               'low': 32,
               'mid': 64, 
               'high': 128}

The simulation in this code is self-contained: all computations are handled within the script, without relying on external molecular dynamics software such as LAMMPS or GROMACS. Everything (potential evaluation, numerical integration, extraction of observables, and model training) is done internally using PyTorch, ASE, and custom simulation modules.

The function get_system() initializes a periodic atomic system in 3D using ASE. We need to define:

  • density
  • temperature
  • lattice type (e.g. FCC)
  • periodic boundary conditions (PBCs)
In [35]:
def get_system(data_str, device, size):
    # here we define an NVT system setup
    rho = pair_data_dict[data_str]['rho']
    T = pair_data_dict[data_str]['T']
    dim = pair_data_dict[data_str].get("dim", 3)

    if dim == 3:
        # initialize states with ASE 
        cell_module = pair_data_dict[data_str]['cell']
        N_unitcell = pair_data_dict[data_str]['N_unitcell']
        def get_unit_len(rho, N_unitcell):
            L = (N_unitcell / rho) ** (1/3)
            return L 
        L = get_unit_len(rho, N_unitcell)
        print("lattice param:", L)
        atoms = cell_module(symbol=pair_data_dict[data_str]['element'],
                                  size=(size, size, size),
                                  latticeconstant= L,
                                  pbc=True)
        system = System(atoms, device=device)
        system.set_temperature(T)

    return system 
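For the FCC lattice used here, the conventional cubic cell contains 4 atoms, so the lattice parameter follows directly from the reduced density: L = (N_unitcell / rho)^(1/3). For rho = 0.3 this reproduces the value that get_system prints:

```python
rho, N_unitcell = 0.3, 4           # values from pair_data_dict['lj_0.3_1.2']
L = (N_unitcell / rho) ** (1 / 3)
print(round(L, 6))  # 2.371262
```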

The system uses a ground-truth potential (target_pot) from pair_data_dict. This potential can be:

  • Lennard-Jones (LJ) for simple atomic interactions
  • a custom pair potential for more complex interactions

The potential is implemented in PyTorch, so forces and energies are computed dynamically, without an external MD engine.
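Because the potential is a differentiable PyTorch module, forces come from automatic differentiation rather than hand-coded derivatives. A minimal sketch of the idea (the class below is illustrative, not the library's PairPotentials, which applies the energy per neighbor pair):

```python
import torch

class LJ(torch.nn.Module):
    """Illustrative pair energy; stands in for the library's potentials."""
    def __init__(self, sigma=1.0, epsilon=1.0):
        super().__init__()
        self.sigma, self.epsilon = sigma, epsilon

    def forward(self, r):
        sr6 = (self.sigma / r) ** 6
        return 4.0 * self.epsilon * (sr6 ** 2 - sr6)

pot = LJ()
r = torch.tensor([2.0 ** (1 / 6)], dtype=torch.float64, requires_grad=True)
u = pot(r).sum()
force = -torch.autograd.grad(u, r)[0]  # F = -dU/dr, via autograd
print(force.item())  # ~0 at the potential minimum
```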

The function get_target_obs() sets up and runs the simulation.

In [36]:
def get_target_obs(system, data_str, n_sim, rdf_range, nbins, t_range, dt, skip=25):
    print("simulating {}".format(data_str))
    device = system.device 
    # simulation setup
    target_pot = pair_data_dict[data_str]['target_pot'] # take target potential
    T = pair_data_dict[data_str]['T'] # take target temperature

    pot = PairPotentials(system, target_pot, cutoff=2.5, nbr_list_device=device).to(device)

    diffeq = NoseHooverChain(pot, 
            system,
            Q=50.0, 
            T=T,
            num_chains=5, 
            adjoint=True,
            topology_update_freq=1).to(system.device)

    # define the simulator
    sim = Simulations(system, diffeq)

    # define objects for the observables
    rdf_obs = rdf(system, nbins=nbins, r_range=rdf_range)
    vacf_obs = vacf(system, t_range=t_range) 
    
    all_vacf_sim = []

    # Run MD Simulations & Extract Target Data
    for i in range(n_sim):
        v_t, q_t, pv_t = sim.simulate(100, dt=dt, frequency=100)  # 100 time steps per run; `frequency` sets how often the state is logged

        if i >= skip:
            vacf_sim = vacf_obs(v_t).detach().cpu().numpy()
            all_vacf_sim.append(vacf_sim)
            
    # loop over to compute observables 
    trajs = torch.Tensor( np.stack( sim.log['positions'])).to(system.device).detach()
    all_g_sim = []
    for i in range(len(trajs)):

        if i >= skip:
            _, _, g_sim = rdf_obs(trajs[[i]])
            all_g_sim.append(g_sim.detach().cpu().numpy())

    all_g_sim = np.array(all_g_sim).mean(0)
    all_vacf_sim = np.array(all_vacf_sim).mean(0)
    
    return all_g_sim, all_vacf_sim
In [37]:
def get_observer(system, data_str, nbins, t_range, rdf_start):
    # get dt 
    dt = pair_data_dict[data_str].get('dt', 0.01)

    rdf_end = pair_data_dict[data_str].get("end", None)

    xnew = np.linspace(rdf_start , rdf_end, nbins)
    
    # initialize observable function 
    obs = rdf(system, nbins, (rdf_start , rdf_end) )
    vacf_obs = vacf(system, t_range=t_range) 

    # get experimental rdf 
    dim = pair_data_dict[data_str].get("dim", 3) 

    rdf_data_path = pair_data_dict[data_str].get("fn", None)
    # generate simulated data 
    if not rdf_data_path:
        rdf_data, vacf_target = get_target_obs(system, data_str, 200, (rdf_start, rdf_end), nbins=nbins, t_range=t_range, skip=50, dt=dt)
        vacf_target = torch.Tensor(vacf_target).to(system.device)
        rdf_data = np.vstack( (np.linspace(rdf_start, rdf_end, nbins), rdf_data))
    else:
        # experimental rdfs
        rdf_data = np.loadtxt(rdf_data_path, delimiter=',')
        vacf_target = None

    _, rdf_target = get_exp_rdf(rdf_data, nbins, (rdf_start, rdf_end), obs.device, dim=dim)

    # get model potential and simulate 

    return xnew, rdf_target, obs, vacf_target, vacf_obs
In [38]:
def get_sim(system, model, data_str, topology_update_freq=1):

    T = pair_data_dict[data_str]['T']

    diffeq = NoseHooverChain(model, 
            system,
            Q=50.0, 
            T=T,
            num_chains=5, 
            adjoint=True,
            topology_update_freq=topology_update_freq).to(system.device)

    # define the simulator
    sim = Simulations(system, diffeq)

    return sim

Define Training Hyperparameters¶

– Set learning rate, epochs, and training options.

In [39]:
params = {
    'val': []
}
sys_params = {
    'val': params['val'],
}
In [78]:
n_epochs = 1000
n_sim = 200
size = 4
cutoff = 2.5
t_range = 50
device = 'cpu'

nbins = 100
tau = 60  # number of MD steps simulated per optimization step (opt_freq)

rdf_start = 0.75
skip = 1

nbr_list_device = 'cpu'
topology_update_freq = 1

data_str_list = ['lj_0.3_1.2']

# Get the ground-truth pair potential
target_pot = LennardJones()

# merge the validation parameter set, if any
if sys_params['val']:
    val_str_list = sys_params['val']
else:
    val_str_list = []

project_name = 'lj_sim'
suggestion_id = 'run0'
model_path = '{}/{}'.format(project_name, suggestion_id)
if not os.path.exists(model_path):
    os.makedirs(model_path)

print("Training for {} epochs".format(n_epochs))

train_vacf = True
Training for 1000 epochs

Initialize the MD System¶

– Set up atoms, boundary conditions, and simulation parameters

In [41]:
system_list = []
for data_str in data_str_list+val_str_list:
    system = get_system(data_str, device, size) 
    system_list.append(system)
lattice param: 2.371262202993375
/Users/francescapagano/miniconda3/envs/myenv_mlmd/lib/python3.9/site-packages/ase/md/md.py:52: FutureWarning: Specify the temperature in K using the 'temperature_K' argument
  warnings.warn(FutureWarning(w))

What is system?¶

System wraps the ASE Atoms object and is the starting point for the simulation:

In [48]:
from ase.visualize import view
view(system_list[0])
Out[48]:
<Popen: returncode: None args: ['/Users/francescapagano/miniconda3/envs/myen...>

Potential¶

  • LJ potential is the physics-informed prior potential
  • The neural network model (pairMLP) learns corrections to the prior potential

class LJFamily(torch.nn.Module):
    def __init__(self, sigma=1.0, epsilon=1.0, attr_pow=6, rep_pow=12):
        super(LJFamily, self).__init__()
        self.sigma = torch.nn.Parameter(torch.Tensor([sigma]))
        self.epsilon = torch.nn.Parameter(torch.Tensor([epsilon]))
        self.attr_pow = attr_pow
        self.rep_pow = rep_pow

    def LJ(self, r, sigma, epsilon):
        return 4 * epsilon * ((sigma/r)**self.rep_pow - (sigma/r)**self.attr_pow)

    def forward(self, x):
        return self.LJ(x, self.sigma, self.epsilon)

mlp_parmas = {'n_gauss': int(cutoff//0.10),  # cutoff // gaussian width
              'r_start': 0.0,
              'r_end': cutoff,
              'n_width': 128,
              'n_layers': 3,
              'nonlinear': 'ELU'}

lj_params = {'epsilon': 0.4, 'sigma': 0.9, "power": 10}

NN = pairMLP(mlp_parmas)
pair = LJFamily(epsilon=2.0, sigma=0.9, rep_pow=6, attr_pow=3)  # alternative prior: ExcludedVolume(lj_params)

model_list = []
for i, data_str in enumerate(data_str_list + val_str_list):

    pairNN = PairPotentials(system_list[i], NN,
                            cutoff=cutoff,
                            nbr_list_device=nbr_list_device).to(device)
    prior = PairPotentials(system_list[i], pair,
                           cutoff=2.5,
                           nbr_list_device=nbr_list_device).to(device)

    model = Stack({'pairnn': pairNN, 'pair': prior})
    model_list.append(model)

Stack combines:

  • a learned NN potential (pairnn)
  • a predefined Lennard-Jones prior (pair)

The model sums the outputs of pairnn and pair to compute the total energy.
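Conceptually, Stack behaves like a module dictionary whose energy outputs are summed. A toy stand-in illustrating just that summation (the real Stack also handles neighbor lists and batching; SumStack and the Linear submodules below are purely illustrative):

```python
import torch

class SumStack(torch.nn.Module):
    """Toy stand-in for Stack: evaluate each submodule and sum the results."""
    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleDict(models)

    def forward(self, x):
        return sum(m(x) for m in self.models.values())

prior = torch.nn.Linear(1, 1, bias=False)       # stands in for the LJ prior
correction = torch.nn.Linear(1, 1, bias=False)  # stands in for pairMLP
stack = SumStack({'pair': prior, 'pairnn': correction})

x = torch.ones(2, 1)
assert torch.allclose(stack(x), prior(x) + correction(x))
```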

Data Flow¶

Atomic Positions → Compute Pairwise Distances (r_ij) →
│
├──> Lennard-Jones Potential (pair) → U_LJ(r_ij)
│
├──> Neural Network Potential (pairnn)
│ ├── GaussianSmearing(r_ij)
│ ├── Feedforward NN (Linear + ELU Layers)
│ └── Output: U_MLP(r_ij)
│
└──> Final Potential: U_total = U_LJ + U_MLP → Sum Over Pairs → Compute System Energy
│
└──> Use for Molecular Dynamics (Forces, Simulations, Training)

In [149]:
def print_model_summary(model):
    print("Model Structure:")
    print("---------------")
    for name, submodel in model.models.items():
        print(f"\n{name}:")
        print("  Network architecture:")
        for child_name, child in submodel.named_children():
            print(f"    {child_name}: {child}")
        
        if hasattr(submodel, 'model'):
            print("\n  Potential network:")
            for layer_name, layer in submodel.model.named_children():
                print(f"    {layer_name}: {layer}")

# Use it on your model
print_model_summary(model)
Model Structure:
---------------

pairnn:
  Network architecture:
    model: pairMLP(
  (smear): GaussianSmearing()
  (layers): ModuleList(
    (0): Linear(in_features=24, out_features=24, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=24, out_features=128, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): ELU(alpha=1.0)
    (8): Linear(in_features=128, out_features=128, bias=True)
    (9): ELU(alpha=1.0)
    (10): Linear(in_features=128, out_features=24, bias=True)
    (11): ELU(alpha=1.0)
    (12): Linear(in_features=24, out_features=1, bias=True)
  )
)

  Potential network:
    smear: GaussianSmearing()
    layers: ModuleList(
  (0): Linear(in_features=24, out_features=24, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=24, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=128, bias=True)
  (5): ELU(alpha=1.0)
  (6): Linear(in_features=128, out_features=128, bias=True)
  (7): ELU(alpha=1.0)
  (8): Linear(in_features=128, out_features=128, bias=True)
  (9): ELU(alpha=1.0)
  (10): Linear(in_features=128, out_features=24, bias=True)
  (11): ELU(alpha=1.0)
  (12): Linear(in_features=24, out_features=1, bias=True)
)

pair:
  Network architecture:
    model: LJFamily()

  Potential network:
In [150]:
def count_parameters(model):
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        if param.requires_grad:
            trainable_params += param_count
        print(f"{name}: {list(param.shape)} ({param_count} parameters)")
    print(f"\nTotal parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

count_parameters(model)
models.pairnn.model.smear.width: [24] (24 parameters)
models.pairnn.model.smear.offsets: [24] (24 parameters)
models.pairnn.model.layers.0.weight: [24, 24] (576 parameters)
models.pairnn.model.layers.0.bias: [24] (24 parameters)
models.pairnn.model.layers.2.weight: [128, 24] (3072 parameters)
models.pairnn.model.layers.2.bias: [128] (128 parameters)
models.pairnn.model.layers.4.weight: [128, 128] (16384 parameters)
models.pairnn.model.layers.4.bias: [128] (128 parameters)
models.pairnn.model.layers.6.weight: [128, 128] (16384 parameters)
models.pairnn.model.layers.6.bias: [128] (128 parameters)
models.pairnn.model.layers.8.weight: [128, 128] (16384 parameters)
models.pairnn.model.layers.8.bias: [128] (128 parameters)
models.pairnn.model.layers.10.weight: [24, 128] (3072 parameters)
models.pairnn.model.layers.10.bias: [24] (24 parameters)
models.pairnn.model.layers.12.weight: [1, 24] (24 parameters)
models.pairnn.model.layers.12.bias: [1] (1 parameters)
models.pair.model.sigma: [1] (1 parameters)
models.pair.model.epsilon: [1] (1 parameters)

Total parameters: 56507
Trainable parameters: 56507
In [52]:
sim_list = [get_sim(system_list[i], 
                    model_list[i], 
                    data_str,
                    topology_update_freq=topology_update_freq) for i, data_str in enumerate(data_str_list + val_str_list)]
sim_list
Out[52]:
[<md.sim.Simulations at 0x128295b50>]

This cell initializes and collects observables for each molecular system in data_str_list + val_str_list. The goal is to compute and store the Radial Distribution Function (RDF) and the Velocity Auto-Correlation Function (VACF), for both the simulation and the target.
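For reference, the VACF at lag t is the average of v(t0) · v(t0 + t) over atoms and time origins, usually normalized so that C(0) = 1. A minimal NumPy version of this definition (the library's vacf observable is the differentiable PyTorch analogue):

```python
import numpy as np

def vacf_np(v):
    """v: (n_frames, n_atoms, 3) velocities -> VACF normalized to C(0) = 1."""
    n = v.shape[0]
    c = np.array([(v[:n - t] * v[t:]).sum(-1).mean() for t in range(n)])
    return c / c[0]

rng = np.random.default_rng(0)
v = rng.normal(size=(100, 32, 3))  # uncorrelated toy "velocities"
c = vacf_np(v)
print(c[0])  # 1.0: perfect correlation at zero lag
```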

In [54]:
rdf_obs_list = [] #Stores functions to compute RDF during training.
vacf_obs_list = [] #Stores functions to compute VACF during training

rdf_target_list = [] #Stores ground truth RDF data for comparison.
vacf_target_list = [] #Stores ground truth VACF data for comparison.
rdf_bins_list = [] #Stores distance bins for RDF computation

for i, data_str in enumerate(data_str_list + val_str_list):
    rdf_start = pair_data_dict[data_str].get("start", 0.75)
    x, rdf_target, rdf_obs, vacf_target, vacf_obs = get_observer(system_list[i],
                                                                    data_str, 
                                                                    nbins, 
                                                                    t_range=t_range,
                                                                    rdf_start=rdf_start)
    rdf_bins_list.append(x)

    rdf_obs_list.append(rdf_obs)
    rdf_target_list.append(rdf_target)
    vacf_obs_list.append(vacf_obs)
    vacf_target_list.append(vacf_target)
simulating lj_0.3_1.2
In [73]:
import torch
import matplotlib.pyplot as plt

# Convert tensor to NumPy array
rdf_target = rdf_target_list[0].cpu().numpy()  # Assuming first entry is needed
plt.figure(figsize=(8, 5))
plt.plot(rdf_bins_list[0], rdf_target, label="RDF Target", color='b', linewidth=2)

plt.xlabel("Distance (Å)")
plt.ylabel("g(r)")
plt.title("Radial Distribution Function (RDF)")
plt.legend()
plt.grid()
plt.show()
In [70]:
import torch
import matplotlib.pyplot as plt
# Convert tensor to NumPy array
vacf_target = vacf_target_list[0].cpu().numpy()  # first (and only) system
time_bins = torch.arange(len(vacf_target)).numpy() * 0.005  # dt = 0.005

plt.figure(figsize=(8, 5))
plt.plot(time_bins, vacf_target, label="VACF Target", color='b', linewidth=2)


plt.xlabel("Time (reduced units)")
plt.ylabel("VACF")
plt.title("Velocity Auto-Correlation Function (VACF)")
plt.legend()
plt.grid()
plt.show()

Initialize Training Components¶

In [74]:
optimizer = torch.optim.Adam(list(NN.parameters()), lr=0.002)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                'min', 
                                                min_lr=1e-6, 
                                                verbose=True, factor = 0.5, patience= 20,
                                                threshold=5e-5)

# logging containers
loss_log = []
obs_log = dict()
/Users/francescapagano/miniconda3/envs/myenv_mlmd/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
  warnings.warn(

Precompute Observables for Training¶

In [75]:
for i, data_str in enumerate(data_str_list + val_str_list):
    obs_log[data_str] = {}
    obs_log[data_str]['rdf'] = []
    obs_log[data_str]['vacf'] = []

Training¶

In [80]:
rdf_weight = 0.95
vacf_weight = 0.0
In [83]:
for i in range(n_epochs):

    loss_rdf = torch.Tensor([0.0]).to(device)
    loss_vacf = torch.Tensor([0.0]).to(device)

    n_train = len(data_str_list)
    
    for j, sim in enumerate(sim_list[:n_train]): 
        data_str = (data_str_list + val_str_list)[j]
        dt = 0.005  # integration time step (reduced units)
        # Simulate 
        v_t, q_t, pv_t = sim.simulate(steps=tau, frequency=tau, dt=dt)
        
        if data_str in val_str_list:
            v_t = v_t.detach()
            q_t = q_t.detach()
            pv_t = pv_t.detach()

        if torch.isnan(q_t.reshape(-1)).sum().item() > 0:
            print("encounter NaN")
            print( 5 - (i / n_epochs) * 5 )
            break
        #_, _, g_sim = rdf_obs_list[j](q_t[::skip])
        # save memory by computing it in serial
        skip = 5
        n_frames = q_t[::skip].shape[0] 
        for idx in range(n_frames):
            if idx == 0:
                _, _, g_sim = rdf_obs_list[j](q_t[::skip][[idx]])
            else:
                g_sim += rdf_obs_list[j](q_t[::skip][[idx]])[2]

        g_sim = g_sim / n_frames

        # compute vacf 
        vacf_sim = vacf_obs_list[j](v_t)

        if data_str in data_str_list:
            if vacf_target_list[j] is not None:
                loss_vacf += (vacf_sim - vacf_target_list[j][:t_range]).pow(2).mean()
            else:
                loss_vacf += 0.0

            drdf = g_sim - rdf_target_list[j]
            loss_rdf += (drdf).pow(2).mean()#+ JS_rdf(g_sim, rdf_target_list[j])

        obs_log[data_str]['rdf'].append(g_sim.detach().cpu().numpy())
        obs_log[data_str]['vacf'].append(vacf_sim.detach().cpu().numpy())

        if i % 5 ==0 :
            if vacf_target_list[j] is not None:
                vacf_target = vacf_target_list[j][:t_range].detach().cpu().numpy()
            else:
                vacf_target = None
            rdf_target = rdf_target_list[j].detach().cpu().numpy()

            plot_vacf(vacf_sim.detach().cpu().numpy(), vacf_target, 
                fn=data_str + "_{}".format(str(i).zfill(3)), 
                dt=dt,
                path=model_path)

            plot_rdf(g_sim.detach().cpu().numpy(), rdf_target, 
                fn=data_str + "_{}".format(str(i).zfill(3)),
                    path=model_path, 
                    start=rdf_start, 
                    nbins=nbins,
                    end=rdf_obs_list[j].r_axis[-1])

        if i % 5 ==0 :
            potential = plot_pair( path=model_path,
                            fn=str(i).zfill(3),
                            model=sim.integrator.model.models['pairnn'].model, 
                            prior=sim.integrator.model.models['pair'].model, 
                            device=device,
                            target_pot=target_pot.to(device),
                            end=cutoff)

    if train_vacf:
        loss = rdf_weight * loss_rdf +  vacf_weight * loss_vacf
    else:
        loss = rdf_weight * loss_rdf

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    
    print(loss_vacf.item(), loss_rdf.item())
    
    scheduler.step(loss)
    
    loss_log.append([loss_vacf.item(), loss_rdf.item() ])

    current_lr = optimizer.param_groups[0]["lr"]

    if current_lr <= 1e-5:
        print("training converged")
        break

    np.savetxt(model_path + '/loss.txt', np.array(loss_log), delimiter=',')

# # save potentials         
# if np.array(loss_log)[-10:, 1].mean() <=  0.005: 
#     np.savetxt(model_path + '/potential.txt',  potential, delimiter=',')
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.04633055999875069 0.010043300688266754
0.20236946642398834 0.05829494819045067
0.13431745767593384 0.11968377977609634
0.06419174373149872 0.16333819925785065
0.038015469908714294 0.14341723918914795
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.02930583991110325 0.09158311784267426
0.04251861944794655 0.090910904109478
0.03419381007552147 0.14977754652500153
0.01796233467757702 0.15637627243995667
0.01547524705529213 0.09807882457971573
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.04618494585156441 0.05017462372779846
0.11627700179815292 0.021105671301484108
0.043983835726976395 0.06279174983501434
0.030113138258457184 0.1281590461730957
0.029843416064977646 0.1573271006345749
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.019405346363782883 0.227510467171669
0.024593450129032135 0.2552144527435303
0.018196361139416695 0.29571297764778137
0.02855883166193962 0.2951563894748688
0.021850736811757088 0.2793882489204407
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.038046326488256454 0.27704596519470215
0.02740645594894886 0.18910230696201324
0.017252592369914055 0.1244056448340416
0.020001379773020744 0.07963059842586517
0.011261511594057083 0.062042560428380966
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.012782521545886993 0.032121773809194565
0.02184639871120453 0.029230860993266106
0.032829657196998596 0.021662866696715355
0.06968507915735245 0.030158495530486107
0.043725598603487015 0.04170085862278938
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.02225402556359768 0.03344520181417465
0.0833972617983818 0.02449600212275982
0.03217475116252899 0.018466275185346603
0.03851333260536194 0.021241605281829834
0.12616121768951416 0.015991264954209328
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.0511041060090065 0.009908926673233509
0.03359217569231987 0.014169629663228989
0.0743744894862175 0.004685154650360346
0.1356428861618042 0.012889456003904343
0.09922919422388077 0.008808398619294167
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.05750194191932678 0.010880760848522186
0.05852023512125015 0.010412020608782768
0.08340760320425034 0.005074560642242432
0.05566220358014107 0.008761439472436905
0.03940587118268013 0.010030992329120636
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.06671109795570374 0.010998988524079323
0.09722784161567688 0.013169880025088787
0.04666656628251076 0.014039542526006699
0.0266016386449337 0.00906712468713522
0.060235146433115005 0.0075272200629115105
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.09385934472084045 0.014995615929365158
0.051053486764431 0.007780490908771753
0.05436474829912186 0.004648578353226185
0.06913043558597565 0.006592984311282635
0.051048196852207184 0.005015496630221605
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.05617053434252739 0.003883780911564827
0.06154339015483856 0.007360624615103006
0.012405047193169594 0.007416481152176857
0.0717056542634964 0.004207460209727287
0.0799887627363205 0.004569487180560827
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.06424430012702942 0.0055542862974107265
0.03803030773997307 0.008251429535448551
0.06833816319704056 0.006393061950802803
0.0764475092291832 0.003876088885590434
0.044118452817201614 0.005164355039596558
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.06917167454957962 0.003544821171090007
0.10460229963064194 0.004990475717931986
0.061822351068258286 0.006077866535633802
0.036580268293619156 0.006292472593486309
0.07136891782283783 0.004308749921619892
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.1084817498922348 0.0019340902799740434
0.05874348059296608 0.0037314030341804028
0.05043186992406845 0.0024738553911447525
0.09408694505691528 0.0021624243818223476
0.09122275561094284 0.0038339602760970592
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.030772924423217773 0.008911254815757275
0.034080762416124344 0.005568155087530613
0.08209067583084106 0.0024854487273842096
0.0682661160826683 0.003161157015711069
0.044390853494405746 0.0038394152652472258
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.07547518610954285 0.001764682587236166
0.07235253602266312 0.004831680096685886
0.09986728429794312 0.004434170667082071
0.04135961830615997 0.0032481388188898563
0.09872177243232727 0.0033892521169036627
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.07601369172334671 0.0022193407639861107
0.08489355444908142 0.0016392340185120702
0.026955999433994293 0.0028773536905646324
0.028381120413541794 0.003647969337180257
0.0642792284488678 0.003932399675250053
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.08144722878932953 0.0026279992889612913
0.03627745434641838 0.003960017580538988
0.09532791376113892 0.002152936765924096
0.08060938119888306 0.002907180693000555
0.029194677248597145 0.0026513785123825073
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.042677298188209534 0.004090467467904091
0.09880457818508148 0.004020046442747116
0.08512777090072632 0.0034153214655816555
0.06658999621868134 0.0025431702379137278
0.03461385518312454 0.004609052557498217
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.05919656157493591 0.0031950033735483885
0.07202740013599396 0.003467089729383588
0.03581404313445091 0.00223873695358634
0.04688747972249985 0.003939190413802862
0.07860332727432251 0.002299464540556073
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.06195599213242531 0.0036474221851676702
0.062103915959596634 0.002677328884601593
0.019062094390392303 0.003090736223384738
0.052140358835458755 0.003462327178567648
0.07102754712104797 0.003356299363076687
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.08360378444194794 0.0033088738564401865
0.03269742801785469 0.0019728736951947212
0.10011973232030869 0.001103093265555799
0.05196038633584976 0.0020374611485749483
0.03210205212235451 0.003018932417035103
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
0.06971368938684464 0.001889049424789846
0.08849235624074936 0.002132921712473035
0.047778040170669556 0.001716303057037294
0.019058817997574806 0.0019802204333245754
0.07318150252103806 0.0015941888559609652
0.06338120996952057 0.0025376901030540466
0.06087113544344902 0.003344047348946333
0.05560697987675667 0.0024015139788389206
0.03515675663948059 0.002795663196593523
0.02359706163406372 0.004246007651090622
0.06295900046825409 0.0037083528004586697
0.024573376402258873 0.0038052713498473167
0.023760592564940453 0.005958267953246832
0.051972780376672745 0.0034134129527956247
0.0634133592247963 0.004205169156193733
0.03414866328239441 0.004493663087487221
0.04584605246782303 0.003998476546257734
0.01875460147857666 0.0039401911199092865
0.0831926092505455 0.0030581483151763678
0.09581798315048218 0.0015265141846612096
0.06330931186676025 0.0038221694994717836
0.07825545966625214 0.002808419754728675
0.1282954216003418 0.003078846260905266
0.042410045862197876 0.0025002574548125267
0.018115798011422157 0.003285877639427781
0.03828461095690727 0.0022987229749560356
0.10932434350252151 0.0019829105585813522
0.05933761224150658 0.001307878759689629
0.05370749533176422 0.003916655667126179
0.04825720936059952 0.0021974476985633373
0.08934042602777481 0.0010675211669877172
0.037287574261426926 0.003348633646965027
0.01513825636357069 0.004311643075197935
0.053105417639017105 0.0013432854320853949
0.12009597569704056 0.002277851104736328
0.050818633288145065 0.004724223166704178
0.03100726567208767 0.001874067005701363
0.09133799374103546 0.0011570833157747984
0.057620562613010406 0.001494632801041007
0.06150764599442482 0.0023313150741159916
0.08368604630231857 0.002848237520083785
0.05097782984375954 0.0016704823356121778
0.03992683067917824 0.0010176339419558644
0.025034936144948006 0.002109982306137681
0.02689593844115734 0.002892452524974942
0.07405448704957962 0.0034658554941415787
0.08389942348003387 0.0028659505769610405
0.02560444548726082 0.002557440660893917
0.03423386439681053 0.003203923348337412
0.054591938853263855 0.002834795042872429
0.022573962807655334 0.0034451084211468697
0.05296829715371132 0.0032701469026505947
0.05343877896666527 0.0017493261257186532
0.04474444314837456 0.002562630223110318
0.0907260999083519 0.0015969823580235243
0.020530246198177338 0.003802410326898098
0.024552499875426292 0.0041845072992146015
0.05318091809749603 0.003999505192041397
0.10100594162940979 0.0021584522910416126
0.06413844227790833 0.0015778415836393833
0.04441334307193756 0.002070643939077854
0.04549916088581085 0.005427600350230932
0.06794867664575577 0.003029446816071868
0.019710682332515717 0.002855653176084161
0.042121294885873795 0.005036984104663134
0.08018523454666138 0.005324619356542826
0.09015438705682755 0.00457659550011158
0.056581366807222366 0.004514981992542744
0.03990184888243675 0.005397412925958633
0.03853808715939522 0.0035117920488119125
0.028920775279402733 0.002667040564119816
0.08216548711061478 0.006145902443677187
0.07655035704374313 0.0020489783491939306
0.036234281957149506 0.002626583678647876
0.042886897921562195 0.00234964513219893
0.06521777808666229 0.0015487339114770293
0.07991483062505722 0.0015492331003770232
0.05614198371767998 0.005295565351843834
0.057382117956876755 0.0019163308897987008
0.07571377605199814 0.002422511810436845
0.06951935589313507 0.0031423550099134445
0.09660250693559647 0.0018593923887237906
0.07215885818004608 0.0025094132870435715
0.03378266841173172 0.0032088959123939276
0.06171073019504547 0.0029725623317062855
0.07213623821735382 0.0014286652440205216
0.038524121046066284 0.0022366081830114126
0.04296727105975151 0.0018621280323714018
0.06337431073188782 0.0011346233077347279
0.07453681528568268 0.0016184784471988678
0.027122242376208305 0.003590609412640333
0.06991582363843918 0.004532129503786564
0.04346119984984398 0.003347467165440321
0.024160807952284813 0.0023364608641713858
0.056272830814123154 0.00215146504342556
0.05117037892341614 0.0027367942966520786
0.07315518707036972 0.0031932187266647816
0.04946345463395119 0.0032869549468159676
0.02238074690103531 0.003056967630982399
0.06628426164388657 0.0038786500226706266
0.031161140650510788 0.0028349151834845543
0.013062585145235062 0.0011921291006729007
0.06554070115089417 0.002175277331843972
0.09362849593162537 0.0021663475781679153
0.030331695452332497 0.002279676264151931
0.048832572996616364 0.0023089454043656588
0.03622838109731674 0.0037944025825709105
0.04960097000002861 0.0022481100168079138
0.05980388820171356 0.002241757931187749
0.06918107718229294 0.002916302066296339
0.03470796346664429 0.0018953901017084718
0.04247225448489189 0.0028281821869313717
0.10363960266113281 0.002418776974081993
0.055149246007204056 0.003054825123399496
0.06060966104269028 0.0014402479864656925
0.04435794800519943 0.0018534105038270354
0.0733352079987526 0.0041503701359033585
0.0321279801428318 0.001898439833894372
0.03973294794559479 0.0020321100018918514
0.08100633323192596 0.002132127992808819
0.04649118334054947 0.0037992498837411404
0.020513296127319336 0.002808330114930868
0.05340828374028206 0.0035707640927284956
0.07799831032752991 0.0023940899409353733
0.03868403285741806 0.001460537314414978
0.06184737756848335 0.001812399597838521
0.147007554769516 0.0011774987215176225
0.04845963045954704 0.0019442903576418757
0.026379775255918503 0.002369264140725136
0.06107716262340546 0.001900967094115913
0.07491497695446014 0.0036729336716234684
0.0637986958026886 0.002132187830284238
0.1180073544383049 0.0033266174141317606
0.06005686894059181 0.002241848036646843
0.06713001430034637 0.0026064224075526
0.07080326229333878 0.002727324375882745
0.06491228938102722 0.0025215521454811096
0.05090469494462013 0.0018093108665198088
0.05472536012530327 0.0030128019861876965
0.06709842383861542 0.004451180808246136
0.0735495537519455 0.0019546186085790396
0.05845671519637108 0.0013967320555821061
0.04122907295823097 0.0018729055300354958
0.06740739941596985 0.003240782069042325
0.07806723564863205 0.002915367716923356
0.03435903787612915 0.0025256327353417873
0.030306274071335793 0.001933713210746646
0.09341749548912048 0.002114050555974245
training converged
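The two columns printed during training are the per-epoch VACF and RDF losses. As a self-contained sanity check (using a three-line excerpt of the log above, not the notebook's `loss_log` variable), the convergence criterion applied later in this notebook — the mean RDF loss over the last 10 logged epochs dropping to 0.005 or below — can be reproduced directly from the printed numbers:

```python
import numpy as np

# Excerpt of the printed log: columns are (vacf_loss, rdf_loss)
log_lines = """\
0.050818633288145065 0.004724223166704178
0.03100726567208767 0.001874067005701363
0.09133799374103546 0.0011570833157747984""".splitlines()

loss_log = np.array([[float(x) for x in line.split()] for line in log_lines])

# Convergence test used in the notebook: mean RDF loss (second column)
# over the last (up to) 10 logged epochs
n = min(10, len(loss_log))
converged = loss_log[-n:, 1].mean() <= 0.005
print(converged)  # True for this excerpt
```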
Evaluation¶

This pass re-runs each simulator without optimization, computes the final observables, and additionally saves the trajectories to XYZ files for visualization.

In [84]:
# Save the learned potential if the final RDF loss is below threshold
if np.array(loss_log)[-10:, 1].mean() <= 0.005: 
    potential = plot_pair(path=model_path,
                          fn="final",
                          model=sim.integrator.model.models['pairnn'].model, 
                          prior=sim.integrator.model.models['pair'].model, 
                          device=device,
                          target_pot=target_pot.to(device),
                          end=cutoff)
    
    np.savetxt(model_path + '/potential.txt', potential, delimiter=',')

rdf_dev = []
import ase
import ase.io

# Function to save trajectory as XYZ file
def save_traj(system, traj, fname, skip=10):
    """
    Saves atomic trajectories as an XYZ file for visualization.

    Parameters:
        system: ASE Atoms object representing the molecular system.
        traj: List of atomic positions at different timesteps.
        fname: Filename for saving the trajectory.
        skip: Interval for saving frames (reduces file size).
    """
    atoms_list = []
    for i, frame in enumerate(traj):
        if i % skip == 0:  # Save every 'skip' frames to reduce file size
            frame = ase.Atoms(positions=frame, numbers=system.get_atomic_numbers())
            atoms_list.append(frame)
    ase.io.write(fname, atoms_list)  # Save as XYZ file

# Final evaluation loop
for j, sim in enumerate(sim_list):

    # Simulate with the trained model (without optimization)
    data_str = (data_str_list + val_str_list)[j]
    train_traj = sim.log['positions']

    # Save training trajectory
    save_traj(system_list[j], train_traj, model_path + '/{}_train.xyz'.format(data_str), skip=10)

    dt = 0.005
    all_vacf_sim = []

    for i in range(n_sim):
        v_t, q_t, pv_t = sim.simulate(steps=tau, frequency=tau, dt=dt)

        # Compute VACF
        vacf_sim = vacf_obs_list[j](v_t).detach().cpu().numpy()
        all_vacf_sim.append(vacf_sim)

    all_vacf_sim = np.array(all_vacf_sim).mean(0)
    
    trajs = torch.Tensor(np.stack(sim.log['positions'])).to(system.device).detach()

    # Get target RDF and VACF
    if vacf_target_list[j] is not None:
        vacf_target = vacf_target_list[j][:t_range].detach().cpu().numpy()
    else:
        vacf_target = None
    rdf_target = rdf_target_list[j].detach().cpu().numpy()
    
    # Compute RDF over all saved trajectories
    all_g_sim = []
    for i in range(len(trajs)):
        _, _, g_sim = rdf_obs_list[j](trajs[[i]])
        all_g_sim.append(g_sim.detach().cpu().numpy())

    all_g_sim = np.array(all_g_sim).mean(0)
    
    # Compute target deviation
    if data_str in data_str_list:
        drdf = np.abs(all_g_sim - rdf_target_list[j].cpu().numpy()).mean()
        rdf_dev.append(drdf) 

    # Plot and save observables (VACF and RDF)
    plot_vacf(all_vacf_sim, vacf_target, 
              fn=data_str, 
              path=model_path,
              dt=dt,
              save_data=True)

    plot_rdf(all_g_sim, rdf_target, 
             fn=data_str,
             path=model_path, 
             start=rdf_start, 
             nbins=nbins,
             save_data=True,
             end=rdf_obs_list[j].r_axis[-1])

    # Save final trajectory (inference trajectory)
    save_traj(system_list[j], np.stack(sim.log['positions']),  
              model_path + '/{}_sim.xyz'.format(data_str), skip=1)
[Figures: simulated vs. target VACF and RDF for the final model; image data was not preserved in this export.]
In [85]:
# Save per-system RDF deviations (the potential was already written above when converged)
np.savetxt(model_path + '/rdf_dev.txt', np.array(rdf_dev), delimiter=',')


# Save the loss curves
loss_arr = np.array(loss_log)
plt.plot(loss_arr[:, 0], label='vacf', alpha=0.7)
plt.plot(loss_arr[:, 1], label='rdf', alpha=0.7)
plt.yscale("log")
plt.legend()

plt.savefig(model_path + '/loss.pdf', bbox_inches='tight')
plt.show()
plt.close()

# Final mean RDF loss over the last 10 epochs
print(loss_arr[-10:, 1].mean())
[Figure: VACF and RDF loss curves on a log scale; image data was not preserved in this export.]
0.0025417785276658834
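Since the arrays are written with `np.savetxt(..., delimiter=',')`, they can be reloaded later with a matching `np.loadtxt`. A round-trip sketch using a temporary directory in place of `model_path`:

```python
import os
import tempfile
import numpy as np

tmpdir = tempfile.mkdtemp()

# Stand-in for the per-system RDF deviations saved above
rdf_dev = np.array([0.012, 0.009, 0.015])
np.savetxt(os.path.join(tmpdir, 'rdf_dev.txt'), rdf_dev, delimiter=',')

# Reload with the same delimiter and check the round trip
loaded = np.loadtxt(os.path.join(tmpdir, 'rdf_dev.txt'), delimiter=',')
print(np.allclose(loaded, rdf_dev))  # True
```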
In [ ]: